# -*- encoding: utf-8 -*-

"""
@File:      basetest.py
@Time:      2022/05/12 16:18:47
@Author:    bolong.tbl
@Version:   1.0
@Contact:   bolong.tbl@alibaba-inc.com
@License:   Mulan PSL v2
"""

import shlex
import time
from enum import Enum

from avocado import Test
from avocado.core import exceptions
from avocado.utils import ssh

from common.container import Container
from common.hosts import LocalHost, RemoteHost
from common.os_env import OSType, Env, Image

class BaseTest(Test):
    """
    BaseTest class for all testcase

    Subclasses are expected to provide ``cmd(command, ..., container_flag=...)``
    returning a ``(return_code, output)`` pair; ``_get_image_info`` and
    ``skip_non_root_test`` depend on it.
    """

    def setUp(self):
        self.log.info('setup')

    def tearDown(self):
        self.log.info('teardown')

    def skip(self, message=None):
        """Skip the running test by raising avocado's ``TestSkipError``."""
        raise exceptions.TestSkipError(message)

    def _get_image_info(self):
        """
        Describe the OS image the test is running on.

        The presence of ``/etc/image-id`` marks the environment as ECS (and its
        ``image_id`` entry is recorded as ``image.id``); otherwise it is treated
        as physical. OS type and version come from ``/etc/os-release``, arch
        and kernel from ``arch`` and ``uname -r``.

        :return: the populated ``Image`` instance.
        """
        image = Image()
        # Only /etc/image-id itself is read below, so do not recurse into
        # sub-directories: a nested image-id file would make the cat fail.
        _, output = self.cmd('find /etc -maxdepth 1 -name "image-id"')
        if output:
            image.env = Env.ECS
            _, output = self.cmd('cat /etc/image-id | grep image_id')
            # Presumably formatted as image_id="<id>"; keep everything after
            # the first '=' and tolerate a missing '=' or trailing whitespace.
            _, sep, value = output.partition('=')
            if sep:
                image.id = value.strip().strip('"')
        else:
            image.env = Env.PHY
        _, output = self.cmd('cat /etc/os-release')
        for line in output.split('\n'):
            if line.startswith('ID='):
                # NOTE(review): raises KeyError for distributions missing from OSType.
                image.ostype = OSType[line.split('=', 1)[1].strip('"').upper()]
            elif line.startswith('VERSION_ID='):
                image.version = line.split('=', 1)[1].strip('"')
        _, image.arch = self.cmd('arch')
        _, image.kernel = self.cmd('uname -r')
        self.log.debug('image.env: {}, image.ostype: {}, image.version: {}, image.arch: {}, image.kernel: {}, image.id: {}'.format(
            image.env, image.ostype, image.version, image.arch, image.kernel, image.id))
        return image

    def skip_non_root_test(self, container_flag=1):
        """
        Skip the test unless the effective user is root (uid 0).

        :param container_flag: forwarded to ``cmd``; the subclasses' ``cmd`` runs
                               the check inside the test container when this is
                               truthy and a container exists.
        """
        _, uid = self.cmd("id -u", container_flag=container_flag)
        # Compare the numeric uid: searching for "root" anywhere in `id` output
        # would also accept non-root users that merely belong to the root group.
        if uid.strip() != "0":
            self.skip("The current user is not root")

class LocalTest(BaseTest):
    """
    LocalTest class represents a testcase running on localhost

    Tips: if you pass remote host parameters to the testcase,
          you can also run the testcase on remote host
    """

    # Package name -> 1 if setUp installed it, 0 if it was already present.
    # setUp rebinds this per instance; the class-level dict is only a default.
    RPM_FLAG = {}

    def setUp(self, param_dic=None):
        """
        Prepare hosts, an optional container and optional packages.

        :param param_dic: optional dict; its ``pkg_name`` entry is a
                          whitespace-separated list of packages to install.
        """
        super().setUp()
        if param_dic is None:
            param_dic = {}
        self.RPM_FLAG = {}
        self.local = LocalHost()
        self.remote = None
        if self.params.get('remote'):
            self.remote = RemoteHost(host=self.params.get('remote'),
                                     username=self.params.get('username'),
                                     port=self.params.get('port', default=22),
                                     key=self.params.get('key', default=None),
                                     password=self.params.get('password'))
        self.container_id = None
        self.version = self.params.get('registry')
        self.image = self._get_image_info()
        self.pkg_name = param_dic.get("pkg_name")
        if self.params.get('engine') is not None:
            self.container_engine = self.params.get('engine')
            self.container_id = Container.create_container(
                self, self.container_engine, self.image, self.version, self.container_id)
        if self.pkg_name:
            self.RPM_FLAG = self.setup_rpm_install(self.pkg_name)

    def cmd(self, command, host="local", ignore_status=False, container_flag=0):
        """
        Run ``command`` and return the host's ``cmd`` result.

        :param host: "remote" runs on the remote host when one is configured;
                     any other value runs locally.
        :param ignore_status: forwarded to the host's ``cmd``.
        :param container_flag: when truthy and a container exists, run the
                               command inside it via ``<engine> exec``.
        """
        if container_flag and self.container_id:
            # shlex.quote keeps commands that contain single quotes intact.
            command = '{} exec {} bash -c {}'.format(
                self.container_engine, self.container_id, shlex.quote(command))
        if self.remote and host == "remote":
            return self.remote.cmd(command, ignore_status=ignore_status)
        return self.local.cmd(command, ignore_status=ignore_status)

    def tearDown(self, param_dic=None):
        """
        Undo setUp: remove the packages it installed and destroy its container.

        getattr is used because setUp may have failed before assigning these
        attributes; an AttributeError here would hide the original failure.
        """
        super().tearDown()
        if getattr(self, 'pkg_name', None):
            self.rpm_uninstall(self.RPM_FLAG)
        if getattr(self, 'container_id', None) is not None:
            Container.destroy_container(
                self, self.container_engine, self.container_id, self.version, self.image)

    def _get_os_name(self, flag=0):
        """Return the NAME field of /etc/os-release (e.g. "Ubuntu")."""
        _, os_name = self.cmd('cat /etc/os-release | grep -w NAME= | awk -F "\\"" "{print \\$2}"',
                              container_flag=flag)
        return os_name

    def setup_rpm_install(self, pkg_name, flag=0):
        """
        Ensure every package in ``pkg_name`` is installed and up to date.

        Missing packages are installed and present ones are upgraded. If a
        missing package is not available, the test is skipped before anything
        is installed.

        :param pkg_name: whitespace-separated package names.
        :param flag: forwarded as ``container_flag`` to :meth:`cmd`.
        :return: dict mapping each package to 1 if this call installed it,
                 0 if it was already installed.
        """
        rpms = pkg_name.split()
        if self._get_os_name(flag) == "Ubuntu":
            self.cmd("apt-get update", ignore_status=True, container_flag=flag)
            pkg_status_cmd = "dpkg -s "
            pkg_info_cmd = "apt show "
            pkg_install_cmd = "apt-get install -y "
            pkg_update_cmd = "apt-get upgrade -y "
        else:
            pkg_status_cmd = "rpm -q "
            pkg_info_cmd = "yum info "
            pkg_install_cmd = "yum install -y "
            pkg_update_cmd = "yum update -y "

        # Pass 1: classify packages and check availability. Skipping here,
        # before any install, matters because a skip raised mid-install would
        # never hand the flags back to setUp, leaking the installed packages.
        rpm_flag = {}
        for rpm in rpms:
            ret_c, _ = self.cmd(pkg_status_cmd + rpm, ignore_status=True, container_flag=flag)
            if ret_c != 0:
                ret_c, _ = self.cmd(pkg_info_cmd + rpm, ignore_status=True, container_flag=flag)
                if ret_c != 0:
                    self.skip("%s is not available." % rpm)
                rpm_flag[rpm] = 1
            else:
                rpm_flag[rpm] = 0

        # Pass 2: install missing packages, upgrade the ones already present.
        for rpm, installed_here in rpm_flag.items():
            prefix = pkg_install_cmd if installed_here == 1 else pkg_update_cmd
            self.cmd(prefix + rpm, container_flag=flag)
        return rpm_flag

    def rpm_uninstall(self, rpm_flag, flag=0):
        """
        Remove the packages that :meth:`setup_rpm_install` installed.

        :param rpm_flag: dict returned by :meth:`setup_rpm_install`; only the
                         entries whose value is 1 are removed.
        :param flag: forwarded as ``container_flag`` to :meth:`cmd`.
        """
        rpms = [rpm for rpm, installed_here in rpm_flag.items() if installed_here == 1]
        if not rpms:
            return
        if self._get_os_name(flag) == "Ubuntu":
            pkg_remove_cmd = "apt-get remove -y "
        else:
            pkg_remove_cmd = "yum erase -y "
        # Remove every package this test installed (not just the last key),
        # in the same environment (host/container) where it was detected.
        self.cmd(pkg_remove_cmd + " ".join(rpms), container_flag=flag)


class RemoteTest(BaseTest):
    """
    RemoteTest class represents a testcase running on remote host

    :param remote (str): remote host ip or hostname
    :param username (str): remote host username
    :param port (int): remote host port
    :param key (str): remote host key
    :param password (str): remote host password
    """

    # Package name -> 1 if setUp installed it, 0 if it was already present.
    # setUp rebinds this per instance; the class-level dict is only a default.
    RPM_FLAG = {}

    def setUp(self, param_dic=None):
        """
        Prepare the remote host, an optional container and optional packages.

        :param param_dic: optional dict; its ``pkg_name`` entry is a
                          whitespace-separated list of packages to install.
        """
        super().setUp()
        if param_dic is None:
            param_dic = {}
        self.RPM_FLAG = {}
        port = self.params.get('port', default=22)
        self.remote = RemoteHost(host=self.params.get('remote'),
                                 username=self.params.get('username'),
                                 port=port,
                                 key=self.params.get('key', default=None),
                                 password=self.params.get('password'))
        self.container_id = None
        self.version = self.params.get('registry')
        # Always defined so later code can test it instead of hitting AttributeError.
        self.session = None
        if self.remote.host and self.remote.host != 'localhost':
            # Connect on the configured port, the same one given to RemoteHost.
            # NOTE(review): the 'key' parameter is still not forwarded here.
            self.session = ssh.Session(self.remote.host, port=port,
                                       user=self.remote.username,
                                       password=self.remote.password)
            self.session.connect()
        self.local = LocalHost()
        self.image = self._get_image_info()
        if self.params.get('engine') is not None:
            self.container_engine = self.params.get('engine')
            self.container_id = Container.create_container(
                self, self.container_engine, self.image, self.version, self.container_id)
        self.pkg_name = param_dic.get("pkg_name")
        if self.pkg_name:
            self.RPM_FLAG = self.setup_rpm_install(self.pkg_name)

    def cmd(self, command, host="remote", ignore_status=False, container_flag=0):
        """
        Run ``command`` and return the host's ``cmd`` result.

        :param host: "remote" (default) runs on the remote host when its host
                     is set; any other value runs locally.
        :param ignore_status: forwarded to the host's ``cmd``.
        :param container_flag: when truthy and a container exists, run the
                               command inside it via ``<engine> exec``.
        """
        if container_flag and self.container_id:
            # shlex.quote keeps commands that contain single quotes intact.
            command = '{} exec {} bash -c {}'.format(
                self.container_engine, self.container_id, shlex.quote(command))
        if host == 'remote' and self.remote.host is not None:
            return self.remote.cmd(command, ignore_status=ignore_status)
        return self.local.cmd(command, ignore_status=ignore_status)

    def wait_ssh_connect(self, timeout=600):
        """
        Retry ``self.remote.session.connect()`` roughly once per second.

        NOTE(review): setUp stores its session on ``self.session``, while this
        method uses ``self.remote.session``; confirm RemoteHost exposes one.

        :param timeout: maximum number of failed attempts (about one second each).
        :return: True once a connection succeeds, False if every attempt failed.
        """
        attempts = timeout
        time.sleep(1)
        while attempts > 0:
            if self.remote.session.connect():
                self.log.info('ssh connect successfully')
                return True
            time.sleep(1)
            attempts -= 1
        self.log.warning('ssh connect failed after %s attempts', timeout)
        return False

    def tearDown(self, param_dic=None):
        """
        Undo setUp: remove the packages it installed and destroy its container.

        getattr is used because setUp may have failed before assigning these
        attributes; an AttributeError here would hide the original failure.
        """
        super().tearDown()
        if getattr(self, 'pkg_name', None):
            self.rpm_uninstall(self.RPM_FLAG)
        if getattr(self, 'container_id', None) is not None:
            Container.destroy_container(
                self, self.container_engine, self.container_id, self.version, self.image)

    def _get_os_name(self, flag=0):
        """Return the NAME field of /etc/os-release (e.g. "Ubuntu")."""
        _, os_name = self.cmd('cat /etc/os-release | grep -w NAME= | awk -F "\\"" "{print \\$2}"',
                              container_flag=flag)
        return os_name

    def setup_rpm_install(self, pkg_name, flag=0):
        """
        Ensure every package in ``pkg_name`` is installed and up to date.

        Missing packages are installed and present ones are upgraded. If a
        missing package is not available, the test is skipped before anything
        is installed.

        :param pkg_name: whitespace-separated package names.
        :param flag: forwarded as ``container_flag`` to :meth:`cmd`.
        :return: dict mapping each package to 1 if this call installed it,
                 0 if it was already installed.
        """
        rpms = pkg_name.split()
        if self._get_os_name(flag) == "Ubuntu":
            self.cmd("apt-get update", ignore_status=True, container_flag=flag)
            pkg_status_cmd = "dpkg -s "
            pkg_info_cmd = "apt show "
            pkg_install_cmd = "apt-get install -y "
            pkg_update_cmd = "apt-get upgrade -y "
        else:
            pkg_status_cmd = "rpm -q "
            pkg_info_cmd = "yum info "
            pkg_install_cmd = "yum install -y "
            pkg_update_cmd = "yum update -y "

        # Pass 1: classify packages and check availability. Skipping here,
        # before any install, matters because a skip raised mid-install would
        # never hand the flags back to setUp, leaking the installed packages.
        rpm_flag = {}
        for rpm in rpms:
            ret_c, _ = self.cmd(pkg_status_cmd + rpm, ignore_status=True, container_flag=flag)
            if ret_c != 0:
                ret_c, _ = self.cmd(pkg_info_cmd + rpm, ignore_status=True, container_flag=flag)
                if ret_c != 0:
                    self.skip("%s is not available." % rpm)
                rpm_flag[rpm] = 1
            else:
                rpm_flag[rpm] = 0

        # Pass 2: install missing packages, upgrade the ones already present.
        for rpm, installed_here in rpm_flag.items():
            prefix = pkg_install_cmd if installed_here == 1 else pkg_update_cmd
            self.cmd(prefix + rpm, container_flag=flag)
        return rpm_flag

    def rpm_uninstall(self, rpm_flag, flag=0):
        """
        Remove the packages that :meth:`setup_rpm_install` installed.

        :param rpm_flag: dict returned by :meth:`setup_rpm_install`; only the
                         entries whose value is 1 are removed.
        :param flag: forwarded as ``container_flag`` to :meth:`cmd`.
        """
        rpms = [rpm for rpm, installed_here in rpm_flag.items() if installed_here == 1]
        if not rpms:
            return
        if self._get_os_name(flag) == "Ubuntu":
            pkg_remove_cmd = "apt-get remove -y "
        else:
            pkg_remove_cmd = "yum erase -y "
        # Remove every package this test installed (not just the last key),
        # in the same environment (host/container) where it was detected.
        self.cmd(pkg_remove_cmd + " ".join(rpms), container_flag=flag)
